# coding=utf-8
import random
import math
from environment import Agent, Environment
from planner import RoutePlanner
from simulator import Simulator


class LearningAgent(Agent):
    """ An agent that learns to drive in the Smartcab world.

        The agent keeps a tabular Q-function in self.Q, mapping each observed
        state tuple to a dict of {action: Q-value}.
        - When 'learning' is True, it explores with probability epsilon and
          otherwise acts greedily on the Q-table.
        - When 'learning' is False, it always picks a random action and never
          builds or updates a Q-table. """

    def __init__(self, env, learning=False, epsilon=1.0, alpha=0.5):
        super(LearningAgent, self).__init__(env)     # Set the agent in the environment
        self.planner = RoutePlanner(self.env, self)  # Create a route planner
        self.valid_actions = self.env.valid_actions  # The set of valid actions

        # Set parameters of the learning agent
        self.learning = learning  # Whether the agent is expected to learn
        self.Q = {}               # Q-table: state tuple -> {action: Q-value}
        self.epsilon = epsilon    # Random exploration factor
        self.alpha = alpha        # Learning factor

        # Number of training trials seen so far; input to the epsilon decay schedules
        self.trial_count = 0

        # Learning-progress bookkeeping, printed by learn().
        # Distinct states = light(2) * oncoming(4) * left(4) * right(4) * waypoint(3) = 384
        self.overall_states = 2 * 4 * 4 * 4 * 3
        # Every state paired with every valid action
        self.overall_state_action = self.overall_states * len(self.valid_actions)
        # (state, action) pairs that have received a non-zero learning update.
        # A set guarantees that each pair is counted at most once.
        self._updated_state_actions = set()
        self.state_action_count = 0

    def reset(self, destination=None, testing=False):
        """ The reset function is called at the beginning of each trial.

            Routes the planner to 'destination' and updates epsilon/alpha.
            'testing' is set to True if testing trials are being used once
            training trials have completed. In that case epsilon and alpha
            are both set to 0, so the agent acts greedily from the Q-table
            and learn() no longer changes the Q-values. """

        # Select the destination as the new location to route to
        self.planner.route_to(destination)

        # Alternative epsilon decay schedules; one of them is applied per
        # training trial below. They read self.trial_count, which is
        # incremented before the schedule runs, so t >= 1. Results that
        # could leave [0, 1] are clamped, because epsilon is used as a
        # probability in choose_action().

        def clamp(value):
            return max(0.0, min(1.0, value))

        # Linear: epsilon -= a each trial (a = 0.05 -> about 20 training trials)
        def decay_linear(a):
            self.epsilon = clamp(self.epsilon - a)

        # Polynomial: epsilon = 1 / t^a (a = 2 gives 1 / t^2)
        def decay_frac(a):
            self.epsilon = clamp(1.0 / self.trial_count ** a)

        # Exponential: epsilon = a^t (a = 0.995 -> about 600 training trials)
        def decay_exponential(a):
            self.epsilon = a ** self.trial_count

        # Exponential with base e: epsilon = e^(-a * t)
        def decay_exponential_e(a):
            self.epsilon = math.exp(-a * self.trial_count)

        # Cosine: epsilon = cos(a * t), clamped so it never becomes negative
        def decay_cosine(a):
            self.epsilon = clamp(math.cos(a * self.trial_count))

        # Step: leave epsilon unchanged for 'total_trials' trials (it starts
        # at its initial value, 1 by default, which maximises exploration),
        # then drop it to 0 to end training.
        def decay_step(total_trials):
            if self.trial_count > total_trials:
                self.epsilon = 0.0

        if testing:
            # Act purely from the Q-table and stop updating it.
            self.epsilon = 0.0
            self.alpha = 0.0
        else:
            self.trial_count += 1
            decay_exponential(0.995)
            # decay_linear(0.05)
            # decay_step(600)

    def build_state(self):
        """ Build the agent's current state from the environment.

            The state is the tuple (light, oncoming, left, right, waypoint).
            The remaining deadline is also available from the environment,
            but it is deliberately left out of the state. That keeps the
            state space at the 384 combinations counted in __init__ (the
            waypoint only takes 3 values). """

        waypoint = self.planner.next_waypoint()  # The next waypoint
        inputs = self.env.sense(self)            # Visual input - intersection light and traffic

        return (inputs['light'], inputs['oncoming'], inputs['left'], inputs['right'], waypoint)

    def get_maxQ(self, state):
        """ Return the maximum Q-value over all actions for 'state'.

            Raises KeyError if 'state' has not been added to the Q-table. """

        return max(self.Q[state].values())

    def createQ(self, state):
        """ Add 'state' to the Q-table when learning.

            Every valid action for the state starts with a Q-value of 0.0.
            Nothing happens when not learning or when the state is already
            present. """

        if self.learning and state not in self.Q:
            # Key the entry by the same action set that choose_action samples
            # from, so learn() always finds the chosen action.
            self.Q[state] = {action: 0.0 for action in self.valid_actions}

    def choose_action(self, state):
        """ Choose the action to take in 'state'.

            - When not learning: always a random valid action.
            - When learning: a random valid action with probability epsilon
              (exploration); otherwise one of the actions whose Q-value
              equals the state's maximum (exploitation). """

        # Set the agent state
        self.state = state
        self.next_waypoint = self.planner.next_waypoint()

        # Exploration. For example, with epsilon = 0.7 the agent picks a
        # random action 70% of the time and the greedy action 30% of the
        # time. The `or` short-circuits, so no random number is drawn when
        # not learning.
        if not self.learning or random.random() < self.epsilon:
            return random.choice(self.valid_actions)

        # Exploitation. Several actions can share the maximum Q-value (e.g.
        # in a fresh state every value is 0.0), so break ties at random
        # rather than always favouring the same action.
        max_q = self.get_maxQ(state)
        best_actions = [action for action, q in self.Q[state].items() if q == max_q]
        return random.choice(best_actions)

    def learn(self, state, action, reward):
        """ Update the Q-value of (state, action) with the received reward.

            Uses Q <- (1 - alpha) * Q + alpha * reward. Future rewards are
            not considered (there is no discount factor gamma). Does nothing
            when not learning, because no Q-table entry exists for the state
            in that mode. Also prints learning-progress statistics. """

        if not self.learning:
            return

        prev_q = self.Q[state][action]
        self.Q[state][action] = (1 - self.alpha) * prev_q + self.alpha * reward

        # Progress display: state coverage and state-action coverage.
        # A pair counts only when the update changed its value by a
        # non-zero amount (alpha > 0, i.e. not during testing) and it has
        # not been counted before.
        if self.alpha > 0 and reward != 0:
            self._updated_state_actions.add((state, action))
            self.state_action_count = len(self._updated_state_actions)
        print('Trial Count = {}'.format(self.trial_count))
        print('Q-Table Size = {} / {}'.format(len(self.Q), self.overall_states))
        print('Q-Table Non-zero Item Count = {} / {}'.format(self.state_action_count, self.overall_state_action))

    def update(self):
        """ The update function is called when a time step is completed in the 
            environment for a given trial. This function will build the agent
            state, choose an action, receive a reward, and learn if enabled. """

        state = self.build_state()           # Get current state
        self.createQ(state)                  # Create 'state' in Q-table
        action = self.choose_action(state)   # Choose an action
        reward = self.env.act(self, action)  # Receive a reward
        self.learn(state, action, reward)    # Q-learn
        

def run():
    """ Build the environment, attach a Q-learning agent and run the simulation.

        Press ESC to close the simulation, or [SPACE] to pause the simulation. """

    # Environment. Available flags:
    #   verbose     - True to display additional output from the simulation
    #   num_dummies - number of dummy agents in the environment (default 100)
    #   grid_size   - number of intersections (columns, rows) (default (8, 6))
    env = Environment()

    # Driving agent. Available flags:
    #   learning - True to make the agent use Q-learning
    #   epsilon  - exploration factor (default 1)
    #   alpha    - learning rate (default 0.5)
    agent_options = {'learning': True, 'alpha': 0.5}
    agent = env.create_agent(LearningAgent, **agent_options)

    # Follow the agent; enforce_deadline=True enforces the deadline metric.
    env.set_primary_agent(agent, enforce_deadline=True)

    # Simulation. Available flags:
    #   update_delay - seconds between actions (default 2.0)
    #   display      - False to disable the GUI if PyGame is enabled
    #   log_metrics  - True to log trial and simulation results to /logs
    #   optimized    - True to change the default log file name
    sim_options = dict(update_delay=0.001, display=False, log_metrics=True, optimized=True)
    sim = Simulator(env, **sim_options)

    # Run. Available flags:
    #   tolerance - epsilon tolerance before testing begins (default 0.05)
    #   n_test    - number of testing trials to perform (default 0)
    sim.run(n_test=10)


if __name__ == '__main__':
    # Start the simulation only when this file is executed as a script, not when it is imported.
    run()
